#[burn_tensor_testgen::testgen(module_conv2d)]
mod tests {
    use super::*;
    use burn_tensor::module::conv2d;
    use burn_tensor::ops::ConvOptions;
    use burn_tensor::{Shape, Tensor};

    #[test]
    fn test_conv2d_simple() {
        // Two channels in, two channels out, 3x3 kernel with padding 1 and stride 1:
        // each spatial dim keeps its size, (4 + 2 * 1 - 3) / 1 + 1 = 4.
        let case = Conv2dTestCase {
            batch_size: 1,
            height: 4,
            width: 4,
            channels_in: 2,
            channels_out: 2,
            groups: 1,
            kernel_size_1: 3,
            kernel_size_2: 3,
            stride_1: 1,
            stride_2: 1,
            padding_1: 1,
            padding_2: 1,
            dilation_1: 1,
            dilation_2: 1,
        };

        // Expected output, shape [1, 2, 4, 4].
        let expected = TestTensor::from([[
            [
                [1196., 1796., 1916., 1264.],
                [1881., 2793., 2946., 1923.],
                [2313., 3405., 3558., 2307.],
                [1424., 2072., 2156., 1380.],
            ],
            [
                [2709., 4173., 4509., 3065.],
                [4582., 7006., 7483., 5056.],
                [5878., 8914., 9391., 6304.],
                [4089., 6177., 6477., 4333.],
            ],
        ]]);

        case.assert_output(expected);
    }

    #[test]
    fn test_conv2d_simple_implicit() {
        // One input channel fanned out to 16 output channels with a 4x4 kernel,
        // padding 1, stride 1: each spatial dim is (5 + 2 * 1 - 4) / 1 + 1 = 4.
        // NOTE(review): "implicit" in the name presumably targets an implicit-GEMM
        // convolution path in some backend — confirm which kernel this exercises.
        let test = Conv2dTestCase {
            batch_size: 1,
            channels_in: 1,
            channels_out: 16,
            kernel_size_1: 4,
            kernel_size_2: 4,
            padding_1: 1,
            padding_2: 1,
            stride_1: 1,
            stride_2: 1,
            dilation_1: 1,
            dilation_2: 1,
            groups: 1,
            height: 5,
            width: 5,
        };

        // Expected output, shape [1, 16, 4, 4], for the ramp inputs that
        // `Conv2dTestCase::assert_output` builds.
        test.assert_output(TestTensor::from([[
            [
                [666., 916., 1030., 774.],
                [1124., 1500., 1620., 1190.],
                [1604., 2100., 2220., 1610.],
                [990., 1264., 1330., 936.],
            ],
            [
                [1531., 2165., 2471., 1927.],
                [2757., 3805., 4181., 3207.],
                [4197., 5685., 6061., 4587.],
                [3295., 4433., 4691., 3529.],
            ],
            [
                [2396., 3414., 3912., 3080.],
                [4390., 6110., 6742., 5224.],
                [6790., 9270., 9902., 7564.],
                [5600., 7602., 8052., 6122.],
            ],
            [
                [3261., 4663., 5353., 4233.],
                [6023., 8415., 9303., 7241.],
                [9383., 12855., 13743., 10541.],
                [7905., 10771., 11413., 8715.],
            ],
            [
                [4126., 5912., 6794., 5386.],
                [7656., 10720., 11864., 9258.],
                [11976., 16440., 17584., 13518.],
                [10210., 13940., 14774., 11308.],
            ],
            [
                [4991., 7161., 8235., 6539.],
                [9289., 13025., 14425., 11275.],
                [14569., 20025., 21425., 16495.],
                [12515., 17109., 18135., 13901.],
            ],
            [
                [5856., 8410., 9676., 7692.],
                [10922., 15330., 16986., 13292.],
                [17162., 23610., 25266., 19472.],
                [14820., 20278., 21496., 16494.],
            ],
            [
                [6721., 9659., 11117., 8845.],
                [12555., 17635., 19547., 15309.],
                [19755., 27195., 29107., 22449.],
                [17125., 23447., 24857., 19087.],
            ],
            [
                [7586., 10908., 12558., 9998.],
                [14188., 19940., 22108., 17326.],
                [22348., 30780., 32948., 25426.],
                [19430., 26616., 28218., 21680.],
            ],
            [
                [8451., 12157., 13999., 11151.],
                [15821., 22245., 24669., 19343.],
                [24941., 34365., 36789., 28403.],
                [21735., 29785., 31579., 24273.],
            ],
            [
                [9316., 13406., 15440., 12304.],
                [17454., 24550., 27230., 21360.],
                [27534., 37950., 40630., 31380.],
                [24040., 32954., 34940., 26866.],
            ],
            [
                [10181., 14655., 16881., 13457.],
                [19087., 26855., 29791., 23377.],
                [30127., 41535., 44471., 34357.],
                [26345., 36123., 38301., 29459.],
            ],
            [
                [11046., 15904., 18322., 14610.],
                [20720., 29160., 32352., 25394.],
                [32720., 45120., 48312., 37334.],
                [28650., 39292., 41662., 32052.],
            ],
            [
                [11911., 17153., 19763., 15763.],
                [22353., 31465., 34913., 27411.],
                [35313., 48705., 52153., 40311.],
                [30955., 42461., 45023., 34645.],
            ],
            [
                [12776., 18402., 21204., 16916.],
                [23986., 33770., 37474., 29428.],
                [37906., 52290., 55994., 43288.],
                [33260., 45630., 48384., 37238.],
            ],
            [
                [13641., 19651., 22645., 18069.],
                [25619., 36075., 40035., 31445.],
                [40499., 55875., 59835., 46265.],
                [35565., 48799., 51745., 39831.],
            ],
        ]]));
    }

    #[test]
    fn test_conv2d_implicit_padded_in_channels() {
        // Three input channels to 16 output channels with a 3x3 kernel, padding 1,
        // stride 1: each spatial dim keeps its size, (4 + 2 * 1 - 3) / 1 + 1 = 4.
        // NOTE(review): the name suggests an input-channel count (3) that an
        // implicit-GEMM backend must pad to its vector width — confirm the intent.
        let test = Conv2dTestCase {
            batch_size: 1,
            channels_in: 3,
            channels_out: 16,
            kernel_size_1: 3,
            kernel_size_2: 3,
            padding_1: 1,
            padding_2: 1,
            stride_1: 1,
            stride_2: 1,
            dilation_1: 1,
            dilation_2: 1,
            groups: 1,
            height: 4,
            width: 4,
        };

        // Expected output, shape [1, 16, 4, 4], for the ramp inputs that
        // `Conv2dTestCase::assert_output` builds.
        test.assert_output(TestTensor::from([[
            [
                [4521., 6753., 7014., 4635.],
                [6858., 10197., 10548., 6939.],
                [7830., 11601., 11952., 7839.],
                [5007., 7383., 7590., 4953.],
            ],
            [
                [10516., 15988., 16735., 11278.],
                [16822., 25507., 26587., 17875.],
                [19738., 29827., 30907., 20719.],
                [13594., 20506., 21199., 14188.],
            ],
            [
                [16511., 25223., 26456., 17921.],
                [26786., 40817., 42626., 28811.],
                [31646., 48053., 49862., 33599.],
                [22181., 33629., 34808., 23423.],
            ],
            [
                [22506., 34458., 36177., 24564.],
                [36750., 56127., 58665., 39747.],
                [43554., 66279., 68817., 46479.],
                [30768., 46752., 48417., 32658.],
            ],
            [
                [28501., 43693., 45898., 31207.],
                [46714., 71437., 74704., 50683.],
                [55462., 84505., 87772., 59359.],
                [39355., 59875., 62026., 41893.],
            ],
            [
                [34496., 52928., 55619., 37850.],
                [56678., 86747., 90743., 61619.],
                [67370., 102731., 106727., 72239.],
                [47942., 72998., 75635., 51128.],
            ],
            [
                [40491., 62163., 65340., 44493.],
                [66642., 102057., 106782., 72555.],
                [79278., 120957., 125682., 85119.],
                [56529., 86121., 89244., 60363.],
            ],
            [
                [46486., 71398., 75061., 51136.],
                [76606., 117367., 122821., 83491.],
                [91186., 139183., 144637., 97999.],
                [65116., 99244., 102853., 69598.],
            ],
            [
                [52481., 80633., 84782., 57779.],
                [86570., 132677., 138860., 94427.],
                [103094., 157409., 163592., 110879.],
                [73703., 112367., 116462., 78833.],
            ],
            [
                [58476., 89868., 94503., 64422.],
                [96534., 147987., 154899., 105363.],
                [115002., 175635., 182547., 123759.],
                [82290., 125490., 130071., 88068.],
            ],
            [
                [64471., 99103., 104224., 71065.],
                [106498., 163297., 170938., 116299.],
                [126910., 193861., 201502., 136639.],
                [90877., 138613., 143680., 97303.],
            ],
            [
                [70466., 108338., 113945., 77708.],
                [116462., 178607., 186977., 127235.],
                [138818., 212087., 220457., 149519.],
                [99464., 151736., 157289., 106538.],
            ],
            [
                [76461., 117573., 123666., 84351.],
                [126426., 193917., 203016., 138171.],
                [150726., 230313., 239412., 162399.],
                [108051., 164859., 170898., 115773.],
            ],
            [
                [82456., 126808., 133387., 90994.],
                [136390., 209227., 219055., 149107.],
                [162634., 248539., 258367., 175279.],
                [116638., 177982., 184507., 125008.],
            ],
            [
                [88451., 136043., 143108., 97637.],
                [146354., 224537., 235094., 160043.],
                [174542., 266765., 277322., 188159.],
                [125225., 191105., 198116., 134243.],
            ],
            [
                [94446., 145278., 152829., 104280.],
                [156318., 239847., 251133., 170979.],
                [186450., 284991., 296277., 201039.],
                [133812., 204228., 211725., 143478.],
            ],
        ]]));
    }

    #[test]
    fn test_conv2d_groups() {
        // One channel per group (groups == channels_in == channels_out), 3x3 kernel,
        // no padding: each spatial dim shrinks to 5 - 3 + 1 = 3.
        let case = Conv2dTestCase {
            batch_size: 1,
            height: 5,
            width: 5,
            channels_in: 2,
            channels_out: 2,
            groups: 2,
            kernel_size_1: 3,
            kernel_size_2: 3,
            stride_1: 1,
            stride_2: 1,
            padding_1: 0,
            padding_2: 0,
            dilation_1: 1,
            dilation_2: 1,
        };

        // Expected output, shape [1, 2, 3, 3].
        let expected = TestTensor::from([[
            [[312., 348., 384.], [492., 528., 564.], [672., 708., 744.]],
            [
                [3724., 3841., 3958.],
                [4309., 4426., 4543.],
                [4894., 5011., 5128.],
            ],
        ]]);

        case.assert_output(expected);
    }

    #[test]
    fn test_conv2d_groups_multiple_channels() {
        // Four channels split into two groups of two, 3x3 kernel, no padding:
        // each spatial dim shrinks to 5 - 3 + 1 = 3.
        let case = Conv2dTestCase {
            batch_size: 1,
            height: 5,
            width: 5,
            channels_in: 4,
            channels_out: 4,
            groups: 2,
            kernel_size_1: 3,
            kernel_size_2: 3,
            stride_1: 1,
            stride_2: 1,
            padding_1: 0,
            padding_2: 0,
            dilation_1: 1,
            dilation_2: 1,
        };

        // Expected output, shape [1, 4, 3, 3].
        let expected = TestTensor::from([[
            [
                [4035., 4188., 4341.],
                [4800., 4953., 5106.],
                [5565., 5718., 5871.],
            ],
            [
                [10030., 10507., 10984.],
                [12415., 12892., 13369.],
                [14800., 15277., 15754.],
            ],
            [
                [56075., 56876., 57677.],
                [60080., 60881., 61682.],
                [64085., 64886., 65687.],
            ],
            [
                [78270., 79395., 80520.],
                [83895., 85020., 86145.],
                [89520., 90645., 91770.],
            ],
        ]]);

        case.assert_output(expected);
    }

    #[test]
    fn test_conv2d_complex() {
        // Asymmetric everything: a batch of 2, a 3x2 kernel, padding (1, 2),
        // stride (2, 3) and dilation (1, 2).
        // out_h = (4 + 2 * 1 - 1 * (3 - 1) - 1) / 2 + 1 = 2
        // out_w = (5 + 2 * 2 - 2 * (2 - 1) - 1) / 3 + 1 = 3
        let case = Conv2dTestCase {
            batch_size: 2,
            height: 4,
            width: 5,
            channels_in: 3,
            channels_out: 4,
            groups: 1,
            kernel_size_1: 3,
            kernel_size_2: 2,
            stride_1: 2,
            stride_2: 3,
            padding_1: 1,
            padding_2: 2,
            dilation_1: 1,
            dilation_2: 2,
        };

        // Expected output, shape [2, 4, 2, 3].
        let expected = TestTensor::from([
            [
                [[1845., 3789., 1926.], [3210., 6465., 3228.]],
                [[4276., 9082., 4789.], [8071., 16834., 8737.]],
                [[6707., 14375., 7652.], [12932., 27203., 14246.]],
                [[9138., 19668., 10515.], [17793., 37572., 19755.]],
            ],
            [
                [[5445., 10629., 5166.], [8070., 15645., 7548.]],
                [[14356., 28882., 14509.], [22651., 45454., 22777.]],
                [[23267., 47135., 23852.], [37232., 75263., 38006.]],
                [[32178., 65388., 33195.], [51813., 105072., 53235.]],
            ],
        ]);

        case.assert_output(expected);
    }

    /// Configuration of a single `conv2d` test.
    ///
    /// The `_1` suffix refers to the height axis and `_2` to the width axis.
    struct Conv2dTestCase {
        /// Number of samples in the input batch.
        batch_size: usize,
        /// Input spatial height.
        height: usize,
        /// Input spatial width.
        width: usize,
        /// Number of input channels.
        channels_in: usize,
        /// Number of output channels.
        channels_out: usize,
        /// Number of convolution groups.
        groups: usize,
        /// Kernel height.
        kernel_size_1: usize,
        /// Kernel width.
        kernel_size_2: usize,
        /// Stride along the height axis.
        stride_1: usize,
        /// Stride along the width axis.
        stride_2: usize,
        /// Zero padding along the height axis.
        padding_1: usize,
        /// Zero padding along the width axis.
        padding_2: usize,
        /// Dilation along the height axis.
        dilation_1: usize,
        /// Dilation along the width axis.
        dilation_2: usize,
    }

    impl Conv2dTestCase {
        /// Runs `conv2d` with this test case's configuration and asserts that the result
        /// approximately equals `y`.
        ///
        /// The inputs are deterministic. The input, weight and bias tensors are each filled
        /// with `0, 1, 2, ...` through an integer `arange`, which is then converted into a
        /// `TestTensor`. Their shapes are:
        /// - input: `[batch_size, channels_in, height, width]`
        /// - weight: `[channels_out, channels_in / groups, kernel_size_1, kernel_size_2]`
        /// - bias: `[channels_out]`
        ///
        /// # Panics
        ///
        /// - if `groups` is zero, or does not evenly divide `channels_in` and `channels_out`;
        /// - if the output does not match `y` under `assert_approx_eq` with precision `3`.
        fn assert_output(self, y: TestTensor<4>) {
            // Validate the grouping up front. `channels_in / groups` below is integer
            // division, so a non-divisible configuration would silently build a truncated
            // weight shape. The failure would then surface inside `conv2d`, or as a bare
            // divide-by-zero, far from the actual misconfiguration.
            assert!(self.groups > 0, "groups must be greater than zero");
            assert_eq!(
                self.channels_in % self.groups,
                0,
                "channels_in ({}) must be divisible by groups ({})",
                self.channels_in,
                self.groups
            );
            assert_eq!(
                self.channels_out % self.groups,
                0,
                "channels_out ({}) must be divisible by groups ({})",
                self.channels_out,
                self.groups
            );

            let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]);
            let shape_weight = Shape::new([
                self.channels_out,
                self.channels_in / self.groups,
                self.kernel_size_1,
                self.kernel_size_2,
            ]);
            let device = Default::default();
            // Integer ramps converted to float keep every product and sum an exact integer,
            // matching the integer-valued expected tables.
            let weight = TestTensor::from(
                TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
                    .reshape::<4, _>(shape_weight)
                    .into_data(),
            );
            let bias = TestTensor::from(
                TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),
            );
            let x = TestTensor::from(
                TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
                    .reshape::<4, _>(shape_x)
                    .into_data(),
            );
            let output = conv2d(
                x,
                weight,
                Some(bias),
                ConvOptions::new(
                    [self.stride_1, self.stride_2],
                    [self.padding_1, self.padding_2],
                    [self.dilation_1, self.dilation_2],
                    self.groups,
                ),
            );

            y.to_data().assert_approx_eq(&output.into_data(), 3);
        }
    }
}
